{
 "cells": [
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-12-02T09:08:15.737618Z",
     "start_time": "2024-12-02T09:08:09.867132Z"
    }
   },
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torchvision import datasets, transforms\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# Run on the GPU when CUDA is available, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')"
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-12-02T09:08:59.704893Z",
     "start_time": "2024-12-02T09:08:59.686719Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Root folder of the vehicle-classification images, handed to ImageFolder below.\n",
    "data_dir = \"C:/Users/Lenovo/Desktop/深度/实验三数据集/车辆分类数据集\"\n",
    "\n",
    "# Bring every image to 128x128 and convert it to a tensor.\n",
    "transform = transforms.Compose(\n",
    "    [transforms.Resize((128, 128)), transforms.ToTensor()]\n",
    ")\n",
    "\n",
    "dataset = datasets.ImageFolder(root=data_dir, transform=transform)\n",
    "\n",
    "# Shuffled mini-batches of 32 images, loaded with 4 workers.\n",
    "dataloader = DataLoader(\n",
    "    dataset=dataset,\n",
    "    batch_size=32,\n",
    "    shuffle=True,\n",
    "    num_workers=4,\n",
    ")\n"
   ],
   "id": "bb6ef7b820e2bba2",
   "outputs": [],
   "execution_count": 2
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-12-02T09:09:00.085333Z",
     "start_time": "2024-12-02T09:09:00.078402Z"
    }
   },
   "cell_type": "code",
   "source": [
    "class Net(nn.Module):\n",
    "    \"\"\"Small CNN classifier: one conv block, global average pooling, linear head.\n",
    "\n",
    "    Args:\n",
    "        num_classes: Number of output logits (defaults to 3, the original setting).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_classes=3):\n",
    "        super(Net, self).__init__()\n",
    "        self.layer1 = nn.Sequential(\n",
    "            nn.Conv2d(3, 32, kernel_size=3),\n",
    "            nn.BatchNorm2d(32),\n",
    "            nn.ReLU(inplace=True),\n",
    "        )\n",
    "        self.fc = nn.Linear(in_features=32, out_features=num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Compute class logits for a batch of images.\n",
    "\n",
    "        Args:\n",
    "            x: Float tensor of shape (batch, 3, H, W); H and W must be at least 3\n",
    "                because the convolution uses a 3x3 kernel without padding.\n",
    "\n",
    "        Returns:\n",
    "            Tensor of shape (batch, num_classes) holding unnormalised logits.\n",
    "        \"\"\"\n",
    "        x = self.layer1(x)\n",
    "        # Global average pooling: reduce whatever spatial size is left to 1x1.\n",
    "        # The former fixed 126x126 kernel only matched 128x128 inputs.\n",
    "        x = F.adaptive_avg_pool2d(x, 1)\n",
    "        # Flatten everything except the batch dimension. A bare squeeze() also\n",
    "        # removed the batch dimension whenever a batch held a single image.\n",
    "        x = torch.flatten(x, 1)\n",
    "        return self.fc(x)\n"
   ],
   "id": "3465b35c137e7f50",
   "outputs": [],
   "execution_count": 3
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-12-02T09:09:00.636561Z",
     "start_time": "2024-12-02T09:09:00.500903Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Model on the selected device, plus the loss and optimizer that train_net uses.\n",
    "net = Net()\n",
    "net = net.to(device)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(params=net.parameters(), lr=0.01)"
   ],
   "id": "ce76926d050e19bb",
   "outputs": [],
   "execution_count": 4
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-12-02T09:09:00.959436Z",
     "start_time": "2024-12-02T09:09:00.949464Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def train_net(net, dataloader, epochs=10):\n",
    "    \"\"\"Train `net` on `dataloader` and print the loss and accuracy of each epoch.\n",
    "\n",
    "    Uses the notebook-level `device`, `optimizer` and `criterion`, so the cells\n",
    "    that define them must have been run first.\n",
    "\n",
    "    Args:\n",
    "        net: Model to optimise; it is switched to training mode.\n",
    "        dataloader: Iterable of (data, target) batches that supports len().\n",
    "        epochs: Number of passes over `dataloader` (defaults to 10, the value\n",
    "            that used to be hard-coded).\n",
    "\n",
    "    Returns:\n",
    "        (loss, acc) of the final epoch: the mean loss per batch and the fraction\n",
    "        of correctly classified samples, between 0 and 1.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `dataloader` has no batches or `epochs` is less than 1.\n",
    "    \"\"\"\n",
    "    train_batches = len(dataloader)\n",
    "    if train_batches == 0:\n",
    "        raise ValueError('dataloader yields no batches')\n",
    "    if epochs < 1:\n",
    "        raise ValueError('epochs must be at least 1')\n",
    "\n",
    "    net.train()\n",
    "    for epoch in range(epochs):\n",
    "        total_loss = 0\n",
    "        correct = 0\n",
    "        sample_num = 0\n",
    "        for data, target in dataloader:\n",
    "            data = data.to(device).float()\n",
    "            target = target.to(device).long()\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            output = net(data)\n",
    "            batch_loss = criterion(output, target)\n",
    "            batch_loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            total_loss += batch_loss.item()\n",
    "            prediction = torch.argmax(output, 1)\n",
    "            correct += (prediction == target).sum().item()\n",
    "            sample_num += len(prediction)\n",
    "\n",
    "        loss = total_loss / train_batches\n",
    "        # Accuracy is a per-sample ratio. Dividing by the number of batches, as\n",
    "        # before, reported the average count of correct predictions per batch.\n",
    "        acc = correct / sample_num\n",
    "        print('Loss: {:.4f} Acc: {:.4f}'.format(loss, acc))\n",
    "    return loss, acc"
   ],
   "id": "ab07bcf21be67a80",
   "outputs": [],
   "execution_count": 5
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-12-02T09:13:54.182549Z",
     "start_time": "2024-12-02T09:09:01.624359Z"
    }
   },
   "cell_type": "code",
   "source": "# Train with the notebook-level net and dataloader; the returned (loss, acc) tuple is shown as the cell result.\ntrain_net(net,dataloader)",
   "id": "721477d76ab90c3b",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: 0.8475 Acc: 19.6977\n",
      "Loss: 0.7733 Acc: 21.5116\n",
      "Loss: 0.7185 Acc: 23.3721\n",
      "Loss: 0.6963 Acc: 23.3488\n",
      "Loss: 0.6993 Acc: 23.0930\n",
      "Loss: 0.6593 Acc: 24.3488\n",
      "Loss: 0.6633 Acc: 23.9070\n",
      "Loss: 0.6357 Acc: 24.5581\n",
      "Loss: 0.6292 Acc: 24.9767\n",
      "Loss: 0.6008 Acc: 25.2093\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(0.6007900854875875, 25.209302325581394)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 6
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "aaa461d01ed87dfa"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
